{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EEShBncAEcF_"
      },
      "source": [
        "# Model Composition using DeepMind CALM and Gemma\n",
        "\n",
        "Welcome to this step-by-step guide on fine-tuning the [Gemma](https://huggingface.co/google/gemma-2b) using [Hugging Face Transformers](https://huggingface.co/docs/transformers/en/index) and DeepMind's **CALM (Composition to Augment Language Models)** framework.\n",
        "\n",
        "As Large Language Models (LLMs) grow ever larger and more capable, it can be both challenging and costly to extend or adapt them to new domains or tasks. Many solutions involve retraining or fine-tuning a large, general-purpose model on new data—a time-consuming and resource-intensive process. Moreover, organizational constraints or data privacy concerns may limit access to the original training data needed for such adaptation.\n",
        "\n",
        "[**CALM**](https://github.com/google-deepmind/calm) addresses these challenges by enabling the composition of two distinct language models—an “anchor” model with foundational capabilities and an “augmenting” model specialized in a particular domain—without fully re-training the anchor model. CALM does this by introducing cross-attention between models, allowing you to combine their strengths and preserve their original capabilities. The result is a more capable composed model that leverages existing, proven models and a few additional parameters, rather than building new monolithic models from scratch. The library currently supports combining any two models built with the Gemma architecture.\n",
        "\n",
        "[**Transformers**](https://huggingface.co/docs/transformers/en/index) is a powerful and versatile tool for working with a wide range of large language models, tokenizers, and pipelines. It offers a user-friendly API for loading, training, and deploying state-of-the-art models, making it an integral component of the machine learning and natural language processing ecosystem. Its broad compatibility, ease of use, and extensive documentation help streamline tasks like fine-tuning, inference, and evaluation of models.\n",
        "\n",
        "[**Gemma**](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology used to create the Gemini models. They are text-to-text, decoder-only large language models, available in English, with open weights, pre-trained variants, and instruction-tuned variants. Gemma models are well-suited for a variety of text generation tasks, including question answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as a laptop, desktop or your own cloud infrastructure, democratizing access to state of the art AI models and helping foster innovation for everyone.\n",
        "\n",
        "In this notebook, you'll learn how to fine-tune a Composed LLM (CALM) configuration using the `gemma-2-2b` model as both the anchor and augmentation model. The resulting composed model merges capabilities from both instances of `gemma-2-2b`. You could try other combinations out too (`9B` with `2B`), but you'll be keeping it simple with just the `2B` Gemma variant.\n",
        "\n",
        "What you'll learn:\n",
        "1. **Setup & Dependencies**: Installing and importing necessary libraries.\n",
        "2. **Model & Configuration**: Initializing the CALM configuration with `gemma-2-2b` as anchor and augmentation models.\n",
        "3. **Data Loading & Preprocessing**: Using an instruction tuning dataset, tokenizing it, and preparing for fine-tuning.\n",
        "4. **Training**: Setting training arguments and running the fine-tuning using the Hugging Face `Trainer`.\n",
        "5. **Saving & Conclusion**: Saving the fine-tuned model.\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Finetune_with_CALM.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
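    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the cross-attention idea described above more concrete, here is a minimal PyTorch sketch of a single composition step. It is an illustration under stated assumptions, not the CALM library's implementation: the class name `CrossAttentionBridge`, the toy dimensions, and the use of `nn.MultiheadAttention` are all hypothetical choices. The anchor's hidden states act as queries, the augmenting model's hidden states are projected to the anchor's width and act as keys and values, and the attended result is added back to the anchor's states as a residual.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "from torch import nn\n",
        "\n",
        "\n",
        "class CrossAttentionBridge(nn.Module):\n",
        "    # Hypothetical single link from an augmenting model layer to an anchor layer.\n",
        "\n",
        "    def __init__(self, anchor_dim, aug_dim, num_heads):\n",
        "        super().__init__()\n",
        "        # Project augmenting-model states to the anchor model's hidden size.\n",
        "        self.proj = nn.Linear(aug_dim, anchor_dim)\n",
        "        # Anchor states are the queries; projected augmenting states are keys/values.\n",
        "        self.attn = nn.MultiheadAttention(anchor_dim, num_heads, batch_first=True)\n",
        "\n",
        "    def forward(self, anchor_hidden, aug_hidden):\n",
        "        aug_proj = self.proj(aug_hidden)\n",
        "        attended, _ = self.attn(query=anchor_hidden, key=aug_proj, value=aug_proj)\n",
        "        # Residual connection keeps the anchor's original representation intact.\n",
        "        return anchor_hidden + attended\n",
        "\n",
        "\n",
        "# Toy shapes: batch of 1, sequence of 5 tokens, anchor width 64, augmenting width 32.\n",
        "bridge = CrossAttentionBridge(anchor_dim=64, aug_dim=32, num_heads=2)\n",
        "anchor_hidden = torch.randn(1, 5, 64)\n",
        "aug_hidden = torch.randn(1, 5, 32)\n",
        "out = bridge(anchor_hidden, aug_hidden)\n",
        "print(out.shape)\n",
        "```\n",
        "\n",
        "In a setup like this, only the projection and attention weights would need training, which is consistent with composition adding relatively few parameters on top of the two base models."
      ]
    },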
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JnLf0nPHFHQH"
      },
      "source": [
        "## Setup\n",
        "\n",
        "### Selecting the Runtime Environment\n",
        "\n",
        "To start, you can choose either **Google Colab** as your platform.\n",
        "\n",
        "- #### **Google Colab** <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/d/d0/Google_Colaboratory_SVG_Logo.svg/1200px-Google_Colaboratory_SVG_Logo.svg.png\" alt=\"Google Colab\" width=\"30\"/>\n",
        "\n",
        "  1. Click **Open in Colab**.\n",
        "  2. You'll need access to a [**Colab Pro/Pro+**](https://colab.research.google.com/signup) runtime with sufficient resources to run the Gemma model.\n",
        "  3. In the menu, go to **Runtime** > **Change runtime type**.\n",
        "  4. Ensure that the **GPU** is set to **A100**.\n",
        "\n",
        "### Gemma using Hugging Face\n",
        "\n",
        "Before diving into the tutorial, let's set up Gemma:\n",
        "\n",
        "1. **Create a Hugging Face Account**: If you don't have one, you can sign up for a free account [here](https://huggingface.com/join).\n",
        "2. **Access the Gemma Model**: Visit the [Gemma model page](https://huggingface.com/collections/google/gemma-2-release-667d6600fd5220e7b967f315) and accept the usage conditions.\n",
        "3. **Generate a Hugging Face Token**: Go to your Hugging Face [settings page](https://huggingface.com/settings/tokens) and generate a new access token (preferably with `write` permissions). You'll need this token later in the tutorial.\n",
        "\n",
        "**Once you've completed these steps, you're ready to move on to the next section where you'll set up environment variables in your Colab environment.**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "configure-credentials"
      },
      "source": [
        "### Configure Your Credentials\n",
        "\n",
        "To access private models and datasets, you need to log in to the Hugging Face (HF) ecosystem.\n",
        "\n",
        "- #### **Google Colab** <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/d/d0/Google_Colaboratory_SVG_Logo.svg/1200px-Google_Colaboratory_SVG_Logo.svg.png\" alt=\"Google Colab\" width=\"30\"/>\n",
        "  If you're using Colab, you can securely store your Hugging Face token (`HF_TOKEN`) using the Colab Secrets manager:\n",
        "  1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n",
        "  2. **Add Hugging Face Token**:\n",
        "    - Create a new secret with the **name** `HF_TOKEN`.\n",
        "    - Copy/paste your token key into the **Value** input box of `HF_TOKEN`.\n",
        "    - **Toggle** the button on the left to allow notebook access to the secret."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CtjS_1XdGMmg"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "\n",
        "if 'google.colab' in sys.modules:\n",
        "    # Running on Colab\n",
        "    from google.colab import userdata\n",
        "    os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')\n",
        "else:\n",
        "    # Not running on Colab\n",
        "    raise EnvironmentError('This notebook is designed to run on Google Colab.')\n",
        "\n",
        "# Disable tokenizers parallelism to avoid deadlocks\n",
        "os.environ['TOKENIZERS_PARALLELISM'] = 'false'"
      ]
    },
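    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The cell above deliberately stops outside Colab. If you adapt this notebook to another environment, one option is to read the token from an environment variable and fail with a clear message when it is missing. The helper below is a hypothetical sketch (`resolve_hf_token` is not part of any library), shown here with a plain dictionary standing in for the environment.\n",
        "\n",
        "```python\n",
        "import os\n",
        "\n",
        "\n",
        "def resolve_hf_token(env=None):\n",
        "    # Hypothetical helper: return a non-empty HF_TOKEN or raise a clear error.\n",
        "    env = os.environ if env is None else env\n",
        "    token = (env.get('HF_TOKEN') or '').strip()\n",
        "    if not token:\n",
        "        raise RuntimeError('HF_TOKEN is not set; create a token and export it first.')\n",
        "    return token\n",
        "\n",
        "\n",
        "# Example with a stand-in environment (the token value is a placeholder).\n",
        "example_env = {'HF_TOKEN': '  hf_placeholder_token  '}\n",
        "token = resolve_hf_token(example_env)\n",
        "print(len(token))\n",
        "```\n",
        "\n",
        "Stripping whitespace guards against tokens pasted with a trailing newline, which is an easy mistake to make when copying from a browser."
      ]
    },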
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o94dIwSLGOuH"
      },
      "source": [
        "### Setting Up the Environment\n",
        "\n",
        "Next, you'll set up the environment by installing all the necessary Python packages for fine-tuning the Gemma model.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c187fc61-2841-4b2e-b662-5e3d69e1804f"
      },
      "outputs": [],
      "source": [
        "# Clone DeepMind CALM\n",
        "!git clone https://github.com/google-deepmind/calm.git\n",
        "%cd calm"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "acOeyDWO9JCP"
      },
      "source": [
        "You will clone the **CALM** repository and install compatible versions of `transformers`, `datasets`, and `accelerate`.\n",
        "\n",
        "**Note**: You are using pinned versions to ensure compatibility. You may adjust them as new updates become available."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8aPUc0YoEQ0i"
      },
      "outputs": [],
      "source": [
        "# Install the appropriate Hugging Face libraries to ensure compatibility with the Gemma 2 model and CALM.\n",
        "!pip install transformers==4.47.0 -U -q\n",
        "!pip install datasets==3.2.0 -U -q\n",
        "!pip install accelerate==1.2.1 -U -q"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mFa1sCpxEQ0i"
      },
      "source": [
        "## Import the libraries\n",
        "\n",
        "\n",
        "You import the required libraries here for loading and preprocessing the [Abirate/english_quotes](https://huggingface.co/datasets/Abirate/english_quotes) dataset, tokenization, model configuration, and initialization utilities.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oqTKmWpXdsTL"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from transformers import (\n",
        "    AutoTokenizer,\n",
        "    AutoConfig,\n",
        "    AutoModel\n",
        ")\n",
        "# Import the \"calm\" module from the \"model\" package for inference\n",
        "from model import calm"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZUdqL3l9G1N-"
      },
      "source": [
        "## Fine-tune using CALM\n",
        "\n",
        "In this section, you will:\n",
        "1. Load a small dataset (`Abirate/english_quotes`) from Hugging Face Datasets.\n",
        "2. Configure CALM by specifying both the anchor (base) model and the augmentation model. In this demonstration, you will use the same `gemma-2-2b` model for both. However, in practice, you may choose a different variant (e.g., `9B`, `27B`) to combine different capabilities.\n",
        "3. Preprocess the dataset for language modeling.\n",
        "4. Use the `Trainer` from Hugging Face Transformers to fine-tune the composed model.\n",
        "\n",
        "You'll save your training logic into a separate Python script (`train.py`) for clarity."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ikNyeh7IEQ0i"
      },
      "source": [
        "### The Training Script\n",
        "\n",
        "The script below:\n",
        "- Sets up the CALM model configuration.\n",
        "- Loads and tokenizes the dataset.\n",
        "- Defines training arguments and runs the training.\n",
        "- Saves the fine-tuned model.\n",
        "\n",
        "You will specify parameters like `anchor_model_dir`, `aug_model_dir`, `num_heads`, `num_connections`, and other hyperparameters via command-line flags. You can easily adjust these flags for different experiments.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JaB828ooRTbf"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Overwriting train.py\n"
          ]
        }
      ],
      "source": [
        "%%writefile train.py\n",
        "from collections.abc import Sequence\n",
        "from absl import app\n",
        "from absl import flags\n",
        "from absl import logging\n",
        "\n",
        "import datasets\n",
        "\n",
        "# Import the \"calm\" module from the \"model\" package.\n",
        "# This presumably contains the CALM model implementation\n",
        "from model import calm\n",
        "\n",
        "from transformers import (\n",
        "    AutoTokenizer,\n",
        "    AutoConfig,\n",
        "    AutoModel,\n",
        "    DataCollatorForLanguageModeling,\n",
        "    Trainer,\n",
        "    TrainingArguments,\n",
        ")\n",
        "\n",
        "# Register the custom CALMConfig class under the identifier \"calm\" with AutoConfig.\n",
        "# By doing this, when you specify a configuration type as \"calm\", AutoConfig knows\n",
        "# to use calm.CALMConfig to instantiate the configuration.\n",
        "AutoConfig.register(\"calm\", calm.CALMConfig)\n",
        "\n",
        "# Register the CALM model class with AutoModel for the CALMConfig configuration class.\n",
        "# This means that if AutoModel is given a CALMConfig, it knows to instantiate calm.CALM.\n",
        "AutoModel.register(calm.CALMConfig, calm.CALM)\n",
        "\n",
        "_ANCHOR_MODEL_DIR = flags.DEFINE_string('anchor_model_dir', None, 'Path to the anchor model directory or identifier.')\n",
        "_AUG_MODEL_DIR = flags.DEFINE_string('aug_model_dir', None, 'Path to the augmentation model directory or identifier.')\n",
        "_OUTPUT_DIR = flags.DEFINE_string('output_dir', None, 'Directory where the fine-tuned model will be saved.')\n",
        "_LEARNING_RATE = flags.DEFINE_float('learning_rate', 2e-5, 'Learning rate for fine-tuning.')\n",
        "_EPOCHS = flags.DEFINE_integer('epochs', 3, 'Number of training epochs.')\n",
        "_BATCH_SIZE = flags.DEFINE_integer('batch_size', 1, 'Batch size per device.')\n",
        "_NUM_HEADS = flags.DEFINE_integer('num_heads', 1, 'Number of cross-attention heads in CALM.')\n",
        "_NUM_CONNECTIONS = flags.DEFINE_integer('num_connections', 2, 'Number of cross-connections between anchor and aug models.')\n",
        "_LOGGING_STEPS = flags.DEFINE_integer('logging_steps', 1, 'Logging frequency in steps.')\n",
        "_MAX_STEPS = flags.DEFINE_integer('max_steps', -1, 'Max training steps, use -1 for no limit.')\n",
        "\n",
        "def train(argv: Sequence[str]) -> None:\n",
        "    del argv  # Unused.\n",
        "    SEED = 42\n",
        "\n",
        "    anchor_model_path = _ANCHOR_MODEL_DIR.value\n",
        "    aug_model_path = _AUG_MODEL_DIR.value\n",
        "    num_heads = _NUM_HEADS.value\n",
        "    num_connections = _NUM_CONNECTIONS.value\n",
        "\n",
        "    logging.info('Using anchor model: %s', anchor_model_path)\n",
        "    logging.info('Using augmentation model: %s', aug_model_path)\n",
        "\n",
        "    # Load the tokenizer from the anchor model\n",
        "    logging.info('Loading Tokenizer...')\n",
        "    tokenizer = AutoTokenizer.from_pretrained(anchor_model_path)\n",
        "    tokenizer.padding_side = 'right'\n",
        "\n",
        "    # Create CALM config\n",
        "    logging.info('Creating CALM configuration...')\n",
        "    calm_config = calm.CALMConfig(\n",
        "        anchor_model=anchor_model_path,\n",
        "        aug_model=aug_model_path,\n",
        "        anchor_config=None,\n",
        "        aug_config=None,\n",
        "        num_connections=num_connections,\n",
        "        num_heads=num_heads,\n",
        "    )\n",
        "    calm_config.save_pretrained('./calm_config')\n",
        "\n",
        "    # Initialize the composed CALM model\n",
        "    logging.info('Initializing the CALM model...')\n",
        "    model = calm.CALM(calm_config)\n",
        "    model.config.use_cache = False\n",
        "\n",
        "    # Load the dataset (english_quotes)\n",
        "    logging.info('Loading and preparing dataset...')\n",
        "    dataset = datasets.load_dataset('Abirate/english_quotes', split='all')\n",
        "\n",
        "    # Filter out empty quotes\n",
        "    dataset = dataset.filter(lambda x: len(x[\"quote\"]) > 0)\n",
        "\n",
        "    # For demonstration, use a small subset (e.g., 2048 samples)\n",
        "    dataset = dataset.shuffle(seed=SEED).select(range(2048))\n",
        "\n",
        "    # Tokenize the data\n",
        "    def preprocess_function(examples):\n",
        "        return tokenizer(\n",
        "            examples['quote'], truncation=True, padding='max_length',\n",
        "            max_length=512\n",
        "        )\n",
        "\n",
        "    dataset = dataset.map(preprocess_function, batched=True)\n",
        "\n",
        "    # Data collator for language modeling (no masking since it's causal LM)\n",
        "    data_collator = DataCollatorForLanguageModeling(\n",
        "        tokenizer=tokenizer, mlm=False\n",
        "    )\n",
        "\n",
        "    epochs = _EPOCHS.value\n",
        "    batch_size = _BATCH_SIZE.value\n",
        "    learning_rate = _LEARNING_RATE.value\n",
        "    output_dir = _OUTPUT_DIR.value\n",
        "    logging_steps = _LOGGING_STEPS.value\n",
        "    max_steps = _MAX_STEPS.value\n",
        "\n",
        "    # Split into train/validation sets\n",
        "    dataset = dataset.train_test_split(test_size=0.02)\n",
        "\n",
        "    # TrainingArguments for Hugging Face Trainer\n",
        "    training_args = TrainingArguments(\n",
        "        output_dir=output_dir,\n",
        "        save_strategy='no',\n",
        "        overwrite_output_dir=True,\n",
        "        num_train_epochs=epochs,\n",
        "        per_device_train_batch_size=batch_size,\n",
        "        per_device_eval_batch_size=batch_size,\n",
        "        eval_strategy='epoch',\n",
        "        optim=\"adamw_torch_fused\",\n",
        "        lr_scheduler_type=\"constant\",\n",
        "        warmup_ratio=0.03,\n",
        "        logging_steps=logging_steps,\n",
        "        max_steps=max_steps,\n",
        "        learning_rate=learning_rate,\n",
        "        report_to=\"none\",\n",
        "        seed=SEED,\n",
        "    )\n",
        "\n",
        "    # Initialize the Trainer\n",
        "    trainer = Trainer(\n",
        "        model=model,\n",
        "        args=training_args,\n",
        "        train_dataset=dataset['train'],\n",
        "        eval_dataset=dataset['test'],\n",
        "        data_collator=data_collator,\n",
        "        tokenizer=tokenizer,\n",
        "    )\n",
        "\n",
        "    # Train the model\n",
        "    logging.info('Starting training...')\n",
        "    trainer.can_return_loss = True\n",
        "    trainer.train()\n",
        "    trainer.save_model(output_dir)\n",
        "    print(f'Training complete! Model saved to {output_dir}')\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    app.run(train)"
      ]
    },
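    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The script pads every quote on the right to a fixed length, and `DataCollatorForLanguageModeling` with `mlm=False` then copies `input_ids` into `labels` and replaces padding positions with `-100`. The pure-Python sketch below mimics that behavior on toy data to show what a training example looks like. The helper names, token IDs, and `pad_id=0` are hypothetical; the real values come from the Gemma tokenizer.\n",
        "\n",
        "```python\n",
        "IGNORE_INDEX = -100  # label value that the loss skips\n",
        "\n",
        "\n",
        "def pad_to_length(token_ids, max_length, pad_id):\n",
        "    # Truncate, then pad on the right, and build the matching attention mask.\n",
        "    truncated = token_ids[:max_length]\n",
        "    padding = [pad_id] * (max_length - len(truncated))\n",
        "    input_ids = truncated + padding\n",
        "    attention_mask = [1] * len(truncated) + [0] * len(padding)\n",
        "    return input_ids, attention_mask\n",
        "\n",
        "\n",
        "def make_causal_lm_labels(input_ids, pad_id):\n",
        "    # Labels mirror the inputs, except that padding is excluded from the loss.\n",
        "    return [IGNORE_INDEX if tok == pad_id else tok for tok in input_ids]\n",
        "\n",
        "\n",
        "input_ids, attention_mask = pad_to_length([11, 12, 13], max_length=6, pad_id=0)\n",
        "labels = make_causal_lm_labels(input_ids, pad_id=0)\n",
        "print(input_ids)       # [11, 12, 13, 0, 0, 0]\n",
        "print(attention_mask)  # [1, 1, 1, 0, 0, 0]\n",
        "print(labels)          # [11, 12, 13, -100, -100, -100]\n",
        "```\n",
        "\n",
        "Because short quotes are padded to 512 tokens, most positions in a batch are padding. Masking them with `-100` keeps them from contributing to the training loss."
      ]
    },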
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZNQnSYThEQ0k"
      },
      "source": [
        "### Start fine-tuning\n",
        "\n",
        "You can now fine-tune the composed model. To do this, you'll run a short training run of only 50 steps for demonstration. For a real training job, consider increasing the `max_steps` or `epochs` and using a larger dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B4gWTxqIEQ0k"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2024-12-17 14:15:13.357550: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
            "2024-12-17 14:15:13.375429: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
            "2024-12-17 14:15:13.396556: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
            "2024-12-17 14:15:13.403038: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
            "2024-12-17 14:15:13.418338: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
            "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
            "2024-12-17 14:15:14.712618: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
            "I1217 14:15:16.050302 138709282697856 train.py:50] Using anchor model: google/gemma-2-2b\n",
            "I1217 14:15:16.050770 138709282697856 train.py:51] Using augmentation model: google/gemma-2-2b\n",
            "I1217 14:15:16.050826 138709282697856 train.py:54] Loading Tokenizer...\n",
            "I1217 14:15:17.317491 138709282697856 train.py:59] Creating CALM configuration...\n",
            "I1217 14:15:17.318757 138709282697856 train.py:71] Initializing the CALM model...\n",
            "CALM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.\n",
            "  - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes\n",
            "  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).\n",
            "  - If you are not the owner of the model architecture class, please contact the model code owner to update it.\n",
            "Loading checkpoint shards: 100% 3/3 [00:00<00:00,  3.55it/s]\n",
            "Loading checkpoint shards: 100% 3/3 [00:00<00:00,  3.56it/s]\n",
            "I1217 14:15:20.079115 138709282697856 train.py:76] Loading and preparing dataset...\n",
            "/content/calm/train.py:128: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
            "  trainer = Trainer(\n",
            "I1217 14:15:28.227010 138709282697856 train.py:138] Starting training...\n",
            "  0% 0/50 [00:00<?, ?it/s]The 'batch_size' argument of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'max_batch_size' argument instead.\n",
            "The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.\n",
            "{'loss': 4.0153, 'grad_norm': 20.442930221557617, 'learning_rate': 3e-05, 'epoch': 0.0}\n",
            "{'loss': 3.381, 'grad_norm': 97.37456512451172, 'learning_rate': 3e-05, 'epoch': 0.0}\n",
            "{'loss': 3.6812, 'grad_norm': 32.1188850402832, 'learning_rate': 3e-05, 'epoch': 0.0}\n",
            "{'loss': 2.1305, 'grad_norm': 12.197807312011719, 'learning_rate': 3e-05, 'epoch': 0.0}\n",
            "{'loss': 2.9738, 'grad_norm': 15.338357925415039, 'learning_rate': 3e-05, 'epoch': 0.0}\n",
            "{'loss': 3.3701, 'grad_norm': 64.96607208251953, 'learning_rate': 3e-05, 'epoch': 0.01}\n",
            "{'loss': 1.5749, 'grad_norm': 8.17130184173584, 'learning_rate': 3e-05, 'epoch': 0.01}\n",
            "{'loss': 2.6055, 'grad_norm': 32.14103317260742, 'learning_rate': 3e-05, 'epoch': 0.01}\n",
            "{'loss': 1.4701, 'grad_norm': 14.105799674987793, 'learning_rate': 3e-05, 'epoch': 0.01}\n",
            "{'loss': 1.7774, 'grad_norm': 13.641266822814941, 'learning_rate': 3e-05, 'epoch': 0.01}\n",
            "{'loss': 2.2608, 'grad_norm': 18.464675903320312, 'learning_rate': 3e-05, 'epoch': 0.01}\n",
            "{'loss': 1.008, 'grad_norm': 49.01253890991211, 'learning_rate': 3e-05, 'epoch': 0.01}\n",
            "{'loss': 3.2358, 'grad_norm': 26.21762466430664, 'learning_rate': 3e-05, 'epoch': 0.01}\n",
            "{'loss': 2.751, 'grad_norm': 6.75759220123291, 'learning_rate': 3e-05, 'epoch': 0.01}\n",
            "{'loss': 2.4693, 'grad_norm': 8.144367218017578, 'learning_rate': 3e-05, 'epoch': 0.01}\n",
            "{'loss': 3.3663, 'grad_norm': 12.402542114257812, 'learning_rate': 3e-05, 'epoch': 0.02}\n",
            "{'loss': 3.713, 'grad_norm': 12.237486839294434, 'learning_rate': 3e-05, 'epoch': 0.02}\n",
            "{'loss': 2.2904, 'grad_norm': 14.829988479614258, 'learning_rate': 3e-05, 'epoch': 0.02}\n",
            "{'loss': 1.2942, 'grad_norm': 8.686445236206055, 'learning_rate': 3e-05, 'epoch': 0.02}\n",
            "{'loss': 2.9049, 'grad_norm': 32.51576232910156, 'learning_rate': 3e-05, 'epoch': 0.02}\n",
            "{'loss': 2.9112, 'grad_norm': 37.358924865722656, 'learning_rate': 3e-05, 'epoch': 0.02}\n",
            "{'loss': 2.9893, 'grad_norm': 8.115110397338867, 'learning_rate': 3e-05, 'epoch': 0.02}\n",
            "{'loss': 2.2961, 'grad_norm': 6.279969215393066, 'learning_rate': 3e-05, 'epoch': 0.02}\n",
            "{'loss': 3.096, 'grad_norm': 9.917762756347656, 'learning_rate': 3e-05, 'epoch': 0.02}\n",
            "{'loss': 1.7747, 'grad_norm': 5.083250999450684, 'learning_rate': 3e-05, 'epoch': 0.02}\n",
            "{'loss': 2.2717, 'grad_norm': 39.76508331298828, 'learning_rate': 3e-05, 'epoch': 0.03}\n",
            "{'loss': 2.3121, 'grad_norm': 7.2948689460754395, 'learning_rate': 3e-05, 'epoch': 0.03}\n",
            "{'loss': 0.858, 'grad_norm': 32.15620803833008, 'learning_rate': 3e-05, 'epoch': 0.03}\n",
            "{'loss': 2.8447, 'grad_norm': 18.484107971191406, 'learning_rate': 3e-05, 'epoch': 0.03}\n",
            "{'loss': 2.243, 'grad_norm': 8.482945442199707, 'learning_rate': 3e-05, 'epoch': 0.03}\n",
            "{'loss': 2.2743, 'grad_norm': 6.612159252166748, 'learning_rate': 3e-05, 'epoch': 0.03}\n",
            "{'loss': 1.0838, 'grad_norm': 17.88753890991211, 'learning_rate': 3e-05, 'epoch': 0.03}\n",
            "{'loss': 1.9032, 'grad_norm': 19.264760971069336, 'learning_rate': 3e-05, 'epoch': 0.03}\n",
            "{'loss': 2.4416, 'grad_norm': 28.318525314331055, 'learning_rate': 3e-05, 'epoch': 0.03}\n",
            "{'loss': 2.9863, 'grad_norm': 6.888128757476807, 'learning_rate': 3e-05, 'epoch': 0.03}\n",
            "{'loss': 3.0216, 'grad_norm': 8.766512870788574, 'learning_rate': 3e-05, 'epoch': 0.04}\n",
            "{'loss': 1.691, 'grad_norm': 94.17549133300781, 'learning_rate': 3e-05, 'epoch': 0.04}\n",
            "{'loss': 0.8048, 'grad_norm': 18.521203994750977, 'learning_rate': 3e-05, 'epoch': 0.04}\n",
            "{'loss': 1.5768, 'grad_norm': 11.479368209838867, 'learning_rate': 3e-05, 'epoch': 0.04}\n",
            "{'loss': 3.1555, 'grad_norm': 9.045062065124512, 'learning_rate': 3e-05, 'epoch': 0.04}\n",
            "{'loss': 1.1065, 'grad_norm': 18.796716690063477, 'learning_rate': 3e-05, 'epoch': 0.04}\n",
            "{'loss': 2.3622, 'grad_norm': 7.588309288024902, 'learning_rate': 3e-05, 'epoch': 0.04}\n",
            "{'loss': 2.7275, 'grad_norm': 22.93894386291504, 'learning_rate': 3e-05, 'epoch': 0.04}\n",
            "{'loss': 0.7593, 'grad_norm': 13.026911735534668, 'learning_rate': 3e-05, 'epoch': 0.04}\n",
            "{'loss': 1.2647, 'grad_norm': 9.585477828979492, 'learning_rate': 3e-05, 'epoch': 0.04}\n",
            "{'loss': 2.5192, 'grad_norm': 11.320117950439453, 'learning_rate': 3e-05, 'epoch': 0.05}\n",
            "{'loss': 3.1441, 'grad_norm': 6.723537445068359, 'learning_rate': 3e-05, 'epoch': 0.05}\n",
            "{'loss': 1.9326, 'grad_norm': 17.888809204101562, 'learning_rate': 3e-05, 'epoch': 0.05}\n",
            "{'loss': 0.7309, 'grad_norm': 8.853791236877441, 'learning_rate': 3e-05, 'epoch': 0.05}\n",
            "{'loss': 2.4107, 'grad_norm': 9.26325511932373, 'learning_rate': 3e-05, 'epoch': 0.05}\n",
            "100% 50/50 [01:00<00:00,  1.19s/it]\n",
            "  0% 0/21 [00:00<?, ?it/s]\u001b[A\n",
            " 10% 2/21 [00:00<00:06,  2.77it/s]\u001b[A\n",
            " 14% 3/21 [00:01<00:09,  1.95it/s]\u001b[A\n",
            " 19% 4/21 [00:02<00:10,  1.69it/s]\u001b[A\n",
            " 24% 5/21 [00:02<00:10,  1.57it/s]\u001b[A\n",
            " 29% 6/21 [00:03<00:09,  1.50it/s]\u001b[A\n",
            " 33% 7/21 [00:04<00:09,  1.46it/s]\u001b[A\n",
            " 38% 8/21 [00:05<00:09,  1.44it/s]\u001b[A\n",
            " 43% 9/21 [00:05<00:08,  1.42it/s]\u001b[A\n",
            " 48% 10/21 [00:06<00:07,  1.41it/s]\u001b[A\n",
            " 52% 11/21 [00:07<00:07,  1.40it/s]\u001b[A\n",
            " 57% 12/21 [00:07<00:06,  1.39it/s]\u001b[A\n",
            " 62% 13/21 [00:08<00:05,  1.39it/s]\u001b[A\n",
            " 67% 14/21 [00:09<00:05,  1.39it/s]\u001b[A\n",
            " 71% 15/21 [00:10<00:04,  1.39it/s]\u001b[A\n",
            " 76% 16/21 [00:10<00:03,  1.38it/s]\u001b[A\n",
            " 81% 17/21 [00:11<00:02,  1.38it/s]\u001b[A\n",
            " 86% 18/21 [00:12<00:02,  1.38it/s]\u001b[A\n",
            " 90% 19/21 [00:13<00:01,  1.38it/s]\u001b[A\n",
            " 95% 20/21 [00:13<00:00,  1.38it/s]\u001b[A\n",
            "                                   \n",
            "\u001b[A{'eval_loss': 2.069748640060425, 'eval_runtime': 14.8316, 'eval_samples_per_second': 2.764, 'eval_steps_per_second': 1.416, 'epoch': 0.05}\n",
            "100% 50/50 [01:15<00:00,  1.19s/it]\n",
            "100% 21/21 [00:14<00:00,  1.54it/s]\u001b[A\n",
            "{'train_runtime': 75.2186, 'train_samples_per_second': 1.329, 'train_steps_per_second': 0.665, 'train_loss': 2.31533754825592, 'epoch': 0.05}\n",
            "100% 50/50 [01:15<00:00,  1.50s/it]\n",
            "Training complete! Model saved to ./gemma-ft\n"
          ]
        }
      ],
      "source": [
        "anchor_model_path = 'google/gemma-2-2b'\n",
        "aug_model_path = 'google/gemma-2-2b'\n",
        "\n",
        "# Remove the previous output directory if it exists\n",
        "!rm -rf ./gemma-ft\n",
        "\n",
        "# Run training; IPython expands {anchor_model_path} and {aug_model_path} below\n",
        "!python train.py --anchor_model_dir {anchor_model_path} \\\n",
        "          --aug_model_dir {aug_model_path} \\\n",
        "          --num_heads 2 \\\n",
        "          --num_connections 2 \\\n",
        "          --learning_rate 3e-5 \\\n",
        "          --batch_size 2 \\\n",
        "          --max_steps 50 \\\n",
        "          --output_dir './gemma-ft'"
      ]
    },
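    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To try several settings without editing the cell above by hand, the same command can be generated from a small grid of values. The following is a minimal sketch rather than part of the original workflow: the helper name `build_train_commands` and the per-run output directory names are hypothetical, and the code only builds command strings from the `train.py` flags used above without launching any training.\n",
        "\n",
        "```python\n",
        "import itertools\n",
        "import shlex\n",
        "\n",
        "\n",
        "def build_train_commands(model_dir, learning_rates, batch_sizes, max_steps=50):\n",
        "    # Return one train.py command string per (learning rate, batch size) pair.\n",
        "    commands = []\n",
        "    for lr, bs in itertools.product(learning_rates, batch_sizes):\n",
        "        # Hypothetical naming scheme so that runs do not overwrite each other.\n",
        "        output_dir = f'./gemma-ft-lr{lr}-bs{bs}'\n",
        "        args = [\n",
        "            'python', 'train.py',\n",
        "            '--anchor_model_dir', model_dir,\n",
        "            '--aug_model_dir', model_dir,\n",
        "            '--num_heads', '2',\n",
        "            '--num_connections', '2',\n",
        "            '--learning_rate', str(lr),\n",
        "            '--batch_size', str(bs),\n",
        "            '--max_steps', str(max_steps),\n",
        "            '--output_dir', output_dir,\n",
        "        ]\n",
        "        # shlex.join quotes any argument that needs it for the shell.\n",
        "        commands.append(shlex.join(args))\n",
        "    return commands\n",
        "\n",
        "\n",
        "for cmd in build_train_commands('google/gemma-2-2b', [3e-5, 1e-4], [2, 4]):\n",
        "    print(cmd)\n",
        "```\n",
        "\n",
        "Each printed line can be run from a notebook cell with a leading `!`, or passed to `subprocess.run(shlex.split(cmd))`."
      ]
    },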
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GDbHxwfPHjc5"
      },
      "source": [
        "## Prompt using the newly fine-tuned model\n",
        "\n",
        "Finally, prompt the fine-tuned model to check that it works as intended. First register the custom CALM classes and reload the fine-tuned model. Then use the tokenizer to convert a sample prompt into input IDs, and call `model.generate()` to produce a response.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f-givDOAl3zA"
      },
      "outputs": [],
      "source": [
        "# Register the custom CALMConfig and CALM classes with AutoConfig and AutoModel\n",
        "AutoConfig.register(\"calm\", calm.CALMConfig)\n",
        "AutoModel.register(calm.CALMConfig, calm.CALM)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1fTSl1WpH4GF"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "CALM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.\n",
            "  - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes\n",
            "  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).\n",
            "  - If you are not the owner of the model architecture class, please contact the model code owner to update it.\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "860a03d2c70943b6b152441708658106",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "2327f921e0b8458699a399d81c3e4c70",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "997a1b5f905a46b4bb1e5668ffb0539a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "CALM(\n",
              "  (anchor_model): Gemma2ForCausalLM(\n",
              "    (model): Gemma2Model(\n",
              "      (embed_tokens): Embedding(256000, 2304, padding_idx=0)\n",
              "      (layers): ModuleList(\n",
              "        (0-25): 26 x Gemma2DecoderLayer(\n",
              "          (self_attn): Gemma2Attention(\n",
              "            (q_proj): Linear(in_features=2304, out_features=2048, bias=False)\n",
              "            (k_proj): Linear(in_features=2304, out_features=1024, bias=False)\n",
              "            (v_proj): Linear(in_features=2304, out_features=1024, bias=False)\n",
              "            (o_proj): Linear(in_features=2048, out_features=2304, bias=False)\n",
              "            (rotary_emb): Gemma2RotaryEmbedding()\n",
              "          )\n",
              "          (mlp): Gemma2MLP(\n",
              "            (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)\n",
              "            (up_proj): Linear(in_features=2304, out_features=9216, bias=False)\n",
              "            (down_proj): Linear(in_features=9216, out_features=2304, bias=False)\n",
              "            (act_fn): PytorchGELUTanh()\n",
              "          )\n",
              "          (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
              "          (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
              "          (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
              "          (post_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
              "        )\n",
              "      )\n",
              "      (norm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
              "    )\n",
              "    (lm_head): Linear(in_features=2304, out_features=256000, bias=False)\n",
              "  )\n",
              "  (aug_model): Gemma2ForCausalLM(\n",
              "    (model): Gemma2Model(\n",
              "      (embed_tokens): Embedding(256000, 2304, padding_idx=0)\n",
              "      (layers): ModuleList(\n",
              "        (0-25): 26 x Gemma2DecoderLayer(\n",
              "          (self_attn): Gemma2Attention(\n",
              "            (q_proj): Linear(in_features=2304, out_features=2048, bias=False)\n",
              "            (k_proj): Linear(in_features=2304, out_features=1024, bias=False)\n",
              "            (v_proj): Linear(in_features=2304, out_features=1024, bias=False)\n",
              "            (o_proj): Linear(in_features=2048, out_features=2304, bias=False)\n",
              "            (rotary_emb): Gemma2RotaryEmbedding()\n",
              "          )\n",
              "          (mlp): Gemma2MLP(\n",
              "            (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)\n",
              "            (up_proj): Linear(in_features=2304, out_features=9216, bias=False)\n",
              "            (down_proj): Linear(in_features=9216, out_features=2304, bias=False)\n",
              "            (act_fn): PytorchGELUTanh()\n",
              "          )\n",
              "          (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
              "          (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
              "          (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
              "          (post_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
              "        )\n",
              "      )\n",
              "      (norm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
              "    )\n",
              "    (lm_head): Linear(in_features=2304, out_features=256000, bias=False)\n",
              "  )\n",
              "  (cross_attention_hooks): ModuleList(\n",
              "    (0-1): 2 x CrossAttentionHook(\n",
              "      (proj): Linear(in_features=2304, out_features=2304, bias=True)\n",
              "      (post_attention_layernorm): GemmaRMSNorm((2304,), eps=1e-06)\n",
              "      (cross_attention): MultiheadAttention(\n",
              "        (out_proj): NonDynamicallyQuantizableLinear(in_features=2304, out_features=2304, bias=True)\n",
              "      )\n",
              "    )\n",
              "  )\n",
              ")"
            ]
          },
          "execution_count": 10,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Load the CALM configuration\n",
        "config = calm.CALMConfig.from_pretrained('./calm_config')\n",
        "\n",
        "# Load the composed and fine-tuned model\n",
        "model = calm.CALM.from_pretrained('./gemma-ft', config=config)\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "model.to(device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8X9xUltTH4vb"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Loading Tokenizer...\n",
            "Prompting the model...\n",
            "Life is either a <strong>journey</strong> or a <strong>destination.</strong> If it's the former, you'll never arrive; if it's the latter, you'll never depart.\n",
            "\n",
            "- Anonymous\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "from transformers import AutoTokenizer\n",
        "\n",
        "print('Loading Tokenizer...')\n",
        "tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b', use_fast=True)\n",
        "tokenizer.padding_side = 'right'\n",
        "\n",
        "print('Prompting the model...')\n",
        "prompt = \"Life is either a \"\n",
        "# Tokenize the prompt and move the input IDs to the model's device\n",
        "inputs = tokenizer(prompt, return_tensors='pt').to(device)\n",
        "# Generate up to 40 new tokens; repetition_penalty > 1 discourages repeated tokens\n",
        "outputs = model.generate(**inputs, max_new_tokens=40, use_cache=False,\n",
        "                         repetition_penalty=1.1)\n",
        "# Decode the generated token IDs back into text\n",
        "text = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
        "print(text)"
      ]
    },
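    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `repetition_penalty=1.1` argument in the cell above discourages the model from repeating tokens it has already produced. As a rough illustration of the idea, here is a simplified sketch in plain Python. It is not the Transformers implementation, and the function name `apply_repetition_penalty` and the toy logits are made up for this example. The score of each previously seen token is divided by the penalty when positive and multiplied by it when negative, so a penalty above 1 lowers the score in both cases.\n",
        "\n",
        "```python\n",
        "def apply_repetition_penalty(logits, seen_token_ids, penalty=1.1):\n",
        "    # Return a copy of `logits` with already-generated tokens made less likely.\n",
        "    adjusted = list(logits)\n",
        "    for token_id in set(seen_token_ids):\n",
        "        score = adjusted[token_id]\n",
        "        # Dividing a positive score or multiplying a negative one both lower it.\n",
        "        adjusted[token_id] = score / penalty if score > 0 else score * penalty\n",
        "    return adjusted\n",
        "\n",
        "\n",
        "toy_logits = [2.2, -1.0, 0.5]  # scores for a 3-token toy vocabulary\n",
        "seen = [0, 1]                  # tokens 0 and 1 were already generated\n",
        "print(apply_repetition_penalty(toy_logits, seen))\n",
        "```\n",
        "\n",
        "With `penalty=1.0` the scores are unchanged. Larger values suppress repetition more strongly, at the risk of also penalizing words that legitimately need to recur."
      ]
    },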
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uQBWnY-CEQ0k"
      },
      "source": [
        "You have successfully fine-tuned a CALM-composed model using `gemma-2-2b` as both the anchor and the augmentation model. Although this demonstration is a simple, small-scale example, the same principles apply to larger models and datasets. In this guide, you learned how to compose two Gemma models with CALM into a new model that integrates capabilities from both, expanding its skills without the computational cost of full retraining.\n",
        "\n",
        "### Next steps:\n",
        "- Experiment with different Gemma model variants or other instruction-tuned models.\n",
        "- Use larger datasets and more training steps for better model quality.\n",
        "- Adjust hyperparameters (e.g., learning rate, batch size, epochs) for optimal results.\n",
        "\n",
        "Happy fine-tuning!"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "[Gemma_2]Finetune_with_CALM.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
